import argparse
import time

def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``k`` (int), ``time_limit`` (int, seconds),
        ``filepath`` (str) and ``epsilon`` (float).
    """
    parser = argparse.ArgumentParser(description="Run Julia branch and bound with customizable parameters.")
    parser.add_argument('--k', type=int, default=3, help='Specify the parameter k (default: 3)')
    # The help text previously claimed a default of 800 while the real default is 600.
    parser.add_argument('--time_limit', type=int, default=600, help='Specify the time limit (in seconds, default: %(default)s)')
    parser.add_argument('--filepath', type=str, default='Matrix_CovColon_txt', help='File path of the dataset')
    # argparse does not run ``type`` on non-string defaults, so give a float
    # literal to keep ``args.epsilon`` a float whether or not the flag is passed.
    parser.add_argument('--epsilon', type=float, default=3.0, help='value of the threshold we use in the algorithm')
    return parser.parse_args(argv)


import numpy as np
from scipy.linalg import sqrtm


import random

def sum_of_squares(elements):
    """Return the sum of ``value ** 2`` over ``elements`` (0 for an empty iterable)."""
    total = 0
    for value in elements:
        total = total + value ** 2
    return total

def calculate_median_of_sums(matrix, alpha):
    """Median-of-means estimate of the mean squared entry of ``matrix``.

    All entries are shuffled and dealt round-robin into ``int(d ** alpha)``
    groups, where ``d = matrix.shape[0]``. For each group, the mean of the
    squared entries is computed, and the median of those group means is
    returned.

    The shuffle uses NumPy's global RNG, so ``np.random.seed`` makes the
    result reproducible.

    Args:
        matrix: 2-D NumPy array.
        alpha: exponent controlling the number of groups.

    Returns:
        The median of the per-group means of squares, or ``nan`` if
        ``matrix`` has no entries.
    """
    d = matrix.shape[0]
    entries = np.random.permutation(np.asarray(matrix).ravel())
    if entries.size == 0:
        return float("nan")

    # At least one group, and never more groups than entries. An empty group
    # would otherwise divide by zero when taking its mean.
    num_lists = max(1, min(int(d ** alpha), entries.size))

    group_means = [np.mean(entries[i::num_lists] ** 2) for i in range(num_lists)]
    return np.median(group_means)

def threshold_matrix(matrix, threshold):
    """Return a copy of ``matrix`` with every entry whose magnitude is below ``threshold`` set to 0.

    Entries with ``|x| >= threshold`` (and NaNs) are kept unchanged. The input
    is not modified.
    """
    result = np.array(matrix, copy=True)
    too_small = np.abs(result) < threshold
    result[too_small] = 0
    return result

def find_supp(x, ind):
    """Translate the non-zero positions of ``x`` into entries of ``ind``.

    ``ind[j]`` is the original index that local position ``j`` of ``x``
    refers to. Positions are taken from the first axis of ``np.nonzero(x)``.

    Returns:
        list of ``ind[j]`` for every position ``j`` where ``x`` is non-zero,
        in increasing order of ``j``.
    """
    positions = np.nonzero(x)[0]
    return [ind[pos] for pos in positions]


def find_block_diagonals(A, matrix):
    """Split ``A`` into diagonal blocks following the sparsity pattern of ``matrix``.

    Index ``u`` is linked to index ``w`` when ``matrix[u, w] != 0``. Connected
    components are found with an iterative depth-first search that scans row
    ``matrix[u]`` for neighbours. For each component, the principal submatrix
    of ``A`` on those indices is kept unless it is entirely zero.

    Args:
        A: square matrix whose blocks are returned.
        matrix: square matrix (e.g. a thresholded ``A``) defining the links.

    Returns:
        block_diagonals: list of principal submatrices of ``A``.
        indices: list of index lists (Python ints) into ``A``, one per block,
            in depth-first discovery order.
        d_star: size of the largest retained block (0 if none).
    """
    n_nodes = matrix.shape[1]
    seen = [False] * n_nodes
    components = []

    for start in range(n_nodes):
        if seen[start]:
            continue
        members = []
        pending = [start]
        while pending:
            node = pending.pop()
            if seen[node]:
                continue
            seen[node] = True
            members.append(node)
            # Push unvisited neighbours in increasing index order, as Python ints.
            linked = np.flatnonzero(matrix[node] != 0).tolist()
            pending.extend(nb for nb in linked if not seen[nb])
        components.append(members)

    block_diagonals = []
    indices = []
    for members in components:
        block = A[np.ix_(members, members)]
        if np.all(block == 0):
            continue
        block_diagonals.append(block)
        indices.append(members)

    d_star = max((len(members) for members in indices), default=0)
    return block_diagonals, indices, d_star




def solve_bd_spca(A, k, block_diagonals, indices, time_limit):
    """Solve sparse PCA on each diagonal block independently and keep the best block.

    Blocks with at most ``k`` rows are solved exactly with a symmetric
    eigendecomposition: with no more than ``k`` coordinates the cardinality
    constraint cannot bind. Larger blocks are sent to the Julia
    ``branchAndBound`` routine through the module-level ``Main`` object.

    Args:
        A: the original (unthresholded) square matrix.
        k: sparsity level (maximum support size).
        block_diagonals: list of blocks; ``block_diagonals[i]`` is the
            principal submatrix of ``A`` on ``indices[i]``.
        indices: list of index lists mapping block coordinates to ``A``.
        time_limit: ``timeCap`` passed to each ``branchAndBound`` call.

    Returns:
        (x_best, ind_best, obj_best, original_obj, time_total):
        ``x_best`` is the best solution in the block coordinates of
        ``ind_best`` (``np.zeros((1, 1))`` with ``ind_best == [0]`` if no block
        beat an objective of 0). ``obj_best`` is the objective reported for
        that block. ``original_obj`` is ``x_best.T @ A[ind_best, ind_best] @ x_best``.
        ``time_total`` is the wall-clock time of the solves and of that final
        evaluation; it excludes the Julia data transfers.
    """
    time_total = 0

    obj_best = 0
    ind_best = [0]
    x_best = np.zeros((1, 1))

    for bd, ind in zip(block_diagonals, indices):
        if np.all(bd == 0):
            continue

        if bd.shape[0] <= k:
            # Exact solve: the leading eigenpair. ``eigh`` is used because the
            # block is a symmetric principal submatrix. Unlike ``eig`` it returns
            # real eigenvalues in ascending order, so the last one is the largest.
            start_time = time.time()
            eigenvalues, eigenvectors = np.linalg.eigh(bd)
            if eigenvalues[-1] > obj_best:
                obj_best = eigenvalues[-1]
                ind_best = ind
                x_best = eigenvectors[:, -1]
            time_total += time.time() - start_time
            continue

        # Larger blocks: convert to Julia arrays and run branch-and-bound.
        Main.Sigma = Main.eval("Array{Float64}")(bd)
        Main.data = Main.eval("Array{Float64}")(np.real(sqrtm(bd)))
        Main.eval('prob = problem(data, Sigma)')

        start_time = time.time()
        try:
            results = Main.eval(f"branchAndBound(prob, {k}, timeCap={time_limit})")
            obj, xVal, _, _, _, _, _, _ = results
            if np.real(obj) > obj_best:
                obj_best = np.real(obj)
                ind_best = ind
                x_best = xVal
        except Exception as e:
            # Best effort: a failing block is reported and skipped.
            print("An error occurred:", str(e))
        time_total += time.time() - start_time

    start_time = time.time()
    original_obj = x_best.T @ A[np.ix_(ind_best, ind_best)] @ x_best
    time_total += time.time() - start_time

    return x_best, ind_best, obj_best, original_obj, time_total

def solve_bd_spca_bs(A, k, initial_threshold, a = 0.1, b = 10, max_d = 40, tol = 5e-2, time_limit = 600):
    """Bisection search over the threshold for block-diagonal SPCA.

    The search interval starts at ``[L, U] = [a * initial_threshold,
    b * initial_threshold]``, and a first instance is solved at ``U``.
    Each iteration thresholds ``A`` at the midpoint ``M`` and then:

    * if the largest block ``d_star`` is no bigger than the largest
      affordable ``d_star`` seen so far, sets ``U = M``;
    * if ``d_star`` exceeds ``max_d``, sets ``L = M`` (too expensive);
    * otherwise solves the instance with ``solve_bd_spca``. If the re-evaluated
      objective changed by less than ``1e-2``, it also sets ``U = M``.

    The search stops when ``U - L <= tol`` or the overall ``time_limit``
    (seconds, shared by all instances) is used up.

    Returns:
        (best_x, best_ind, temp_obj, best_obj, total_time, M) from the most
        recently solved instance. ``temp_obj`` is the objective on the
        thresholded blocks and ``best_obj`` is that solution re-evaluated on
        ``A``. ``M`` is the last threshold examined (``U`` if no bisection step
        ran).
    """
    start_time = time.time()
    total_time = 0
    U = b * initial_threshold
    L = a * initial_threshold
    # Defined up front so it can be returned even if the loop never runs.
    M = U

    def _remaining_time():
        # Unused part of the shared time budget, never negative.
        return max(time_limit - (time.time() - start_time), 0)

    S = threshold_matrix(A, U)
    block_diagonals, indices, d_star = find_block_diagonals(A, S)

    print("Now solving the first spca instance.")
    best_x, best_ind, temp_obj, best_obj, time_passed = solve_bd_spca(A, k, block_diagonals, indices, _remaining_time())

    i = 2
    # Largest affordable (<= max_d) d_star visited so far.
    max_d_star = 0
    while U - L > tol:
        if d_star <= max_d:
            max_d_star = max(d_star, max_d_star)
        best_obj_old = best_obj

        M = (U + L) / 2
        start_sorting_time = time.time()
        S = threshold_matrix(A, M)
        block_diagonals, indices, d_star = find_block_diagonals(A, S)

        if d_star <= max_d_star:
            # Results are expected to be the same or worse: move U down.
            U = M
            print(f"Current threshold is {M}, and gives the same d_star.\n")
            continue

        if d_star > max_d:
            # We cannot afford such computation: move L up.
            L = M
            print(f"Current threshold is {M}, d_star is {d_star}, and exceeds computational resource.\n")
            continue

        print(f"Now running block diagonal spca for threshold {M}, with d_star being {d_star}.")
        print(f"This is the {i}-th spca instance.")
        i += 1

        sorting_time = time.time() - start_sorting_time

        # Affordable and potentially better. Give it only the remaining
        # budget, consistent with the first instance.
        best_x, best_ind, temp_obj, best_obj, time_passed = solve_bd_spca(A, k, block_diagonals, indices, _remaining_time())
        print(f"Best obj found is {best_obj}. The runtime for this instance is {time_passed + sorting_time}.")
        print(f"Best index set found is {best_ind}.")

        supp = find_supp(best_x, best_ind)
        print(f'Current support is {supp}.')
        if supp:
            # Re-optimise the leading eigenvector on the found support.
            A_supp = A[supp][:, supp]
            _, V1 = np.linalg.eigh(A_supp)
            y1 = V1[:, -1]
            better_PC_value = y1.T @ A_supp @ y1
            print(f"Better obj found is {better_PC_value}.")

        total_time = time.time() - start_time
        print(f"The total runtime is {total_time}.")

        if total_time >= time_limit:
            print("Time limit reached.\n")
            break

        if abs(best_obj - best_obj_old) < 1e-2:
            print("Unchanged objective value detected.")
            U = M
        print("\n")

    total_time = time.time() - start_time
    print(f"Best opt found is {best_obj}.")
    print(f"Total runtime is {total_time}.")
    print("\n")

    return best_x, best_ind, temp_obj, best_obj, total_time, M
    
from sklearn.metrics import jaccard_score

def evaluate_thresholded_spca(dataset, k, threshold, time_limit=600):
    """Compare baseline branch-and-bound SPCA with the thresholded block-diagonal variant.

    ``dataset`` is loaded with ``np.genfromtxt`` (comma-separated). If the
    loaded array is not square, its columns are treated as variables and it
    is replaced by their covariance matrix.

    Steps:
        1. report the max-abs entry of ``A`` and ``threshold`` relative to it;
        2. threshold ``A`` and report the fraction of zero entries;
        3. split into diagonal blocks;
        4. run the Julia ``branchAndBound`` on the full ``A``;
        5. run ``solve_bd_spca`` on the blocks;
        6. recompute the leading eigenvector on the block support;
        7. compare both supports with the Jaccard index.

    Returns:
        dict with keys ``threshold``, ``epsilon_ratio``, ``zero_proportion``,
        ``baseline_obj``, ``baseline_support``, ``block_obj``,
        ``block_support``, ``better_PC_value_block`` and ``jaccard_index``.
        The last two may be ``None``.
    """
    A = np.genfromtxt(dataset, delimiter=',')
    num_rows, d = A.shape
    if num_rows != d:
        # Data matrix rather than a covariance: use the column covariance.
        A = np.cov(A.T)

    inf_norm = np.abs(A).max()
    epsilon_ratio = threshold / inf_norm
    print(f"Infinity norm of the matrix: {inf_norm}")
    print(f"Ratio of epsilon to infinity norm: {epsilon_ratio:.4f}")

    thresholded_matrix = threshold_matrix(A, threshold)
    zero_proportion = np.sum(thresholded_matrix == 0) / (d * d)
    print(f"Proportion of zero entries after thresholding: {zero_proportion:.4f}")

    block_diagonals, indices, d_star = find_block_diagonals(A, thresholded_matrix)

    # Baseline: branch-and-bound on the whole matrix.
    print("Running baseline Branch-and-Bound...")
    Main.Sigma = Main.eval("Array{Float64}")(A)
    Main.data = Main.eval("Array{Float64}")(np.real(sqrtm(A)))
    Main.eval('prob = problem(data, Sigma)')
    bb_output = Main.eval(f"branchAndBound(prob, {k}, timeCap={time_limit})")
    baseline_obj, baseline_xVal, _, _, _, _, _, _ = bb_output
    baseline_support = np.nonzero(baseline_xVal)[0]
    print(f"Baseline objective: {baseline_obj}")

    # Block-diagonal method on the thresholded blocks.
    print("Running Block-Diagonal SPCA...")
    bd_output = solve_bd_spca(A, k, block_diagonals, indices, time_limit)
    best_x, best_ind, temp_obj, best_obj, total_time = bd_output
    block_support = find_supp(best_x, best_ind)
    print(f"Block-Diagonal objective: {best_obj}")

    if len(block_support) == 0:
        better_PC_value_block = None
        print("Block-diagonal support is empty. Cannot compute better PC value.")
    else:
        sub_matrix = A[np.ix_(block_support, block_support)]
        leading = np.linalg.eigh(sub_matrix)[1][:, -1]
        better_PC_value_block = leading.T @ sub_matrix @ leading
        print(f"Better PC value for block-diagonal method: {better_PC_value_block}")

    if len(baseline_support) > 0 and len(block_support) > 0:
        # Compare supports as 0/1 indicator vectors of length d.
        indicator_baseline = np.zeros(d, dtype=int)
        indicator_baseline[baseline_support] = 1
        indicator_block = np.zeros(d, dtype=int)
        indicator_block[block_support] = 1
        jaccard_idx = jaccard_score(indicator_baseline, indicator_block)
        print(f"Jaccard index of supports: {jaccard_idx:.4f}")
    else:
        jaccard_idx = None
        print("One or both supports are empty. Cannot compute Jaccard index.")

    summary = {}
    summary["threshold"] = threshold
    summary["epsilon_ratio"] = epsilon_ratio
    summary["zero_proportion"] = zero_proportion
    summary["baseline_obj"] = baseline_obj
    summary["baseline_support"] = baseline_support.tolist()
    summary["block_obj"] = best_obj
    summary["block_support"] = block_support
    summary["better_PC_value_block"] = better_PC_value_block
    summary["jaccard_index"] = jaccard_idx
    return summary

import pandas as pd
import json
import os

def process_evaluation(data, txt_dir, csv_output, time_limit=600):
    """
    Processes the evaluation function for each row in the data and saves results to separate .txt files and a summary .csv.
    Args:
        data (pd.DataFrame): DataFrame containing dataset, k, and epsilon values.
        txt_dir (str): Directory to save individual .txt files for each dataset.
        csv_output (str): Path to the output .csv file for summarized results.
        time_limit (int): Time limit for the evaluation function.
    Returns:
        pd.DataFrame: Summary of results.
    """
    # Create directory for txt files if it doesn't exist
    os.makedirs(txt_dir, exist_ok=True)
    
    # List to hold summary results for CSV
    summary_results = []
    
    for idx, row in data.iterrows():
        dataset = row['dataset']
        k = row['k']
        epsilon = row['epsilon']
        txt_file_path = os.path.join(txt_dir, f"{os.path.splitext(dataset)[0]}_k{k}_epsilon{epsilon:.2f}.txt")
        
        try:
            print(f"Processing dataset: {dataset}, k: {k}, epsilon: {epsilon}")
            results = evaluate_thresholded_spca(dataset, k, epsilon, time_limit)
            
            # Write detailed results to a separate TXT file
            with open(txt_file_path, "w") as txt_file:
                txt_file.write(f"Dataset: {dataset}, k: {k}, Epsilon: {epsilon}\n")
                txt_file.write(json.dumps(results, indent=4))
            
            # Append summary results for the CSV file
            summary_results.append({
                "dataset": dataset,
                "k": k,
                "epsilon": epsilon,
                "epsilon_ratio": results.get("epsilon_ratio", None),
                "zero_proportion": results.get("zero_proportion", None),
                "baseline_obj": results.get("baseline_obj", None),
                "block_obj": results.get("block_obj", None),
                "jaccard_index": results.get("jaccard_index", None),
                "better_PC_value_block": results.get("better_PC_value_block", None)
            })
        except Exception as e:
            # Handle and log errors for any dataset
            print(f"Error processing Dataset: {dataset}, k: {k}, Epsilon: {epsilon}")
            with open(txt_file_path, "w") as txt_file:
                txt_file.write(f"Error processing Dataset: {dataset}, k: {k}, Epsilon: {epsilon}\n")
                txt_file.write(str(e))
            
    # Save the summary results to a CSV file
    summary_df = pd.DataFrame(summary_results)
    summary_df.to_csv(csv_output, index=False)
    return summary_df

from scipy.linalg import block_diag

def generate_positive_definite_blocks(num_blocks=10, block_size=10, noise_scale=0.1, num_samples=100):
    """Generate a block-diagonal Wishart-type matrix ``A`` and a symmetric Gaussian noise matrix ``E``.

    Each diagonal block is ``B.T @ B / num_samples``, where ``B`` is a
    ``(num_samples, block_size)`` matrix of i.i.d. standard normals. Such a
    block is always positive semidefinite, and positive definite with
    probability one when ``num_samples >= block_size``.

    ``E`` has i.i.d. ``N(0, noise_scale**2)`` entries on and above the
    diagonal, mirrored below it. The diagonal of ``E`` is noisy too; it is not
    zeroed.

    Draws from NumPy's global RNG, so ``np.random.seed`` controls the output.

    Args:
        num_blocks: number of diagonal blocks.
        block_size: side length of each block.
        noise_scale: standard deviation of the noise entries.
        num_samples: samples per block for the Wishart matrix (default 100).

    Returns:
        (A, E), both of shape ``(num_blocks * block_size, num_blocks * block_size)``.
    """
    blocks = []
    for _ in range(num_blocks):
        B = np.random.normal(loc=0, scale=1, size=(num_samples, block_size))
        blocks.append(B.T @ B / num_samples)

    A = block_diag(*blocks)

    # Generate Gaussian noise on the upper triangle (diagonal included).
    E = np.zeros_like(A)
    upper_indices = np.triu_indices_from(E)
    E[upper_indices] = np.random.normal(scale=noise_scale, size=len(upper_indices[0]))
    # Mirror to make E symmetric. E + E.T counts the diagonal twice, so
    # subtract it once to keep the original diagonal noise.
    E = (E + E.T) - np.diag(np.diag(E))

    return A, E

def perform_spca_with_noise(num_blocks=10, block_size=10, k=5, alpha=0.7, noise_scale=0.1, time_limit=600):
    """Perform SPCA on A and A+E with noise, block-diagonal structure, and epsilon thresholding.

    A block-diagonal matrix ``A`` and symmetric noise ``E`` are drawn with
    ``generate_positive_definite_blocks``.

    * Baseline: the Julia ``branchAndBound`` routine runs on the clean ``A``.
    * Block-diagonal method: ``A + E`` is thresholded at an estimated
      ``epsilon``, split into blocks and solved with ``solve_bd_spca``. The
      resulting vector is then evaluated on the clean ``A``.

    Returns:
        dict of summary statistics. ``block_obj`` is ``None`` when the block
        method returned no index set. ``better_PC_value_block`` and
        ``jaccard_index`` are ``None`` when the block support is empty.
    """
    # Step 1: generate A and symmetric noise E.
    A, E = generate_positive_definite_blocks(num_blocks, block_size, noise_scale)
    A_noisy = A + E
    d = A.shape[0]

    # Step 2: baseline branch-and-bound on the clean matrix A.
    start_time = time.time()
    print("Running baseline Branch-and-Bound...")
    Main.Sigma = Main.eval("Array{Float64}")(A)
    Main.data = Main.eval("Array{Float64}")(np.real(sqrtm(A)))
    Main.eval('prob = problem(data, Sigma)')
    results = Main.eval(f"branchAndBound(prob, {k}, timeCap={time_limit})")
    baseline_obj, baseline_xVal, _, _, _, _, _, _ = results
    baseline_support = np.nonzero(baseline_xVal)[0]
    print(f"Baseline objective: {baseline_obj}")
    time_BB = time.time() - start_time

    # Step 3: block-diagonal method on the noisy matrix.
    start_time = time.time()
    # calculate_median_of_sums gives a robust estimate of the mean squared
    # entry of A+E (presumably a proxy for the noise variance).
    sigma_square = calculate_median_of_sums(A_noisy, alpha)
    epsilon = 2 * np.sqrt(3 * sigma_square * np.log(d))
    print(f"Estimated epsilon: {epsilon}")

    A_noisy_thresholded = threshold_matrix(A_noisy, epsilon)

    print("Finding block-diagonal matrices for thresholded A+E...")
    block_diagonals_AE, indices_AE, d_star = find_block_diagonals(A_noisy, A_noisy_thresholded)
    print(f"Largest block size = {d_star}")

    print("Solving SPCA on thresholded A+E...")
    x_AE, ind_AE, obj_AE, orig_obj_AE, time_AE = solve_bd_spca(A_noisy, k, block_diagonals_AE, indices_AE, time_limit)
    print(f"SPCA on thresholded A+E completed. Objective: {obj_AE}, Time: {time_AE}s")

    block_support = find_supp(x_AE, ind_AE)
    print(f"x_AE shape: {x_AE.shape}, block_support length: {len(block_support)}")

    # Step 4: evaluate on the clean A. Default to None so the returned dict
    # is always well-formed, even when a branch below is skipped.
    original_PC_value = None
    if len(ind_AE) > 0:
        original_PC_value = x_AE.T @ A[np.ix_(ind_AE, ind_AE)] @ x_AE
        print(f"PC value for the original matrix A is {original_PC_value}")

    # Guard on the support itself: an empty support gives a 0x0 block, and
    # V_block[:, -1] would raise IndexError.
    if len(block_support) > 0:
        A_block = A[np.ix_(block_support, block_support)]
        _, V_block = np.linalg.eigh(A_block)
        better_PC_value_block = V_block[:, -1].T @ A_block @ V_block[:, -1]
        print(f"Better PC value for block-diagonal method: {better_PC_value_block}")
    else:
        better_PC_value_block = None
        print("Block-diagonal support is empty. Cannot compute better PC value.")

    time_BD = time.time() - start_time

    # Jaccard index of the two supports, as 0/1 indicator vectors of length d.
    if len(baseline_support) > 0 and len(block_support) > 0:
        baseline_support_vec = np.zeros(d, dtype=int)
        baseline_support_vec[baseline_support] = 1
        block_support_vec = np.zeros(d, dtype=int)
        block_support_vec[block_support] = 1

        jaccard_idx = jaccard_score(baseline_support_vec, block_support_vec)
        print(f"Jaccard index of supports: {jaccard_idx:.4f}")
    else:
        jaccard_idx = None
        print("One or both supports are empty. Cannot compute Jaccard index.")

    return {
        "threshold": epsilon,
        "epsilon_ratio": epsilon / np.max(np.abs(A_noisy)),
        "largest_block_size": d_star,
        "zero_proportion": np.sum(A_noisy_thresholded == 0) / (d * d),
        "baseline_obj": baseline_obj,
        "baseline_support": baseline_support.tolist(),
        "block_obj": original_PC_value,
        "block_noisy_obj": obj_AE,
        "block_support": block_support,
        "better_PC_value_block": better_PC_value_block,
        "jaccard_index": jaccard_idx,
        "BD_time": time_BD,
        "Baseline_time": time_BB
    }


def random_testing(num_blocks=10, block_size=10, alpha=0.7, noise_scale=0.1, time_limit=600, output_csv="random_testing_results.csv", k_values=None, num_trials=10, seed=42):
    """Perform random testing for various k values and summarize results in a CSV file.

    Runs ``perform_spca_with_noise`` ``num_trials`` times for each ``k`` in
    ``k_values``. A trial that raises is recorded as a row with an ``error``
    column instead of aborting the sweep.

    Args:
        num_blocks, block_size, alpha, noise_scale, time_limit: forwarded to
            ``perform_spca_with_noise``.
        output_csv: path of the CSV file written with all results.
        k_values: sparsity levels to test (default ``(2, 3, 5, 7, 10)``).
        num_trials: trials per ``k`` (default 10).
        seed: seed for both NumPy's global RNG and the stdlib ``random``
            module (default 42).

    Returns:
        pd.DataFrame with one row per trial, including ``k`` and ``trial``.
    """
    # Seed both RNGs: helpers called from here may draw from either NumPy's
    # global RNG or the stdlib ``random`` module.
    np.random.seed(seed)
    random.seed(seed)

    if k_values is None:
        k_values = (2, 3, 5, 7, 10)

    results = []

    for k in k_values:
        print(f"Testing for k = {k}")
        for trial in range(1, num_trials + 1):
            print(f"  Trial {trial}/{num_trials}")
            try:
                result = perform_spca_with_noise(
                    num_blocks=num_blocks,
                    block_size=block_size,
                    k=k,
                    alpha=alpha,
                    noise_scale=noise_scale,
                    time_limit=time_limit
                )
                result["k"] = k
                result["trial"] = trial
                results.append(result)
            except Exception as e:
                # Keep going; record the failure in the results table.
                print(f"Error during trial {trial} for k = {k}: {e}")
                results.append({
                    "k": k,
                    "trial": trial,
                    "error": str(e)
                })

    results_df = pd.DataFrame(results)
    results_df.to_csv(output_csv, index=False)
    print(f"Results saved to {output_csv}")

    return results_df



args = parse_args()

# Path to the Julia binary; can be overridden with the JULIA_PATH environment variable.
julia_path = os.environ.get('JULIA_PATH', '/mnt/ws/home/dzhou/julia/julia-1.6.7/bin/julia')
exists = os.path.exists(julia_path)
executable = os.access(julia_path, os.X_OK)

print("Julia path exists:", exists)
print("Julia is executable:", executable)

if not exists or not executable:
    # NOTE(review): execution continues past this warning, so the
    # ``from julia import Main`` below runs without the explicit runtime
    # configured here (presumably PyJulia's default setup) — confirm intended.
    print("Check the path to Julia or permissions.")
else:
    from julia.api import Julia
    jl = Julia(runtime=julia_path, compiled_modules=False)

time_start_loading = time.time()
print("Importing Main...")
# ``Main`` is the module-level Julia handle used by the solver functions above.
from julia import Main
print("Including utilities.jl...")
Main.include("utilities.jl")
print("Including branchAndBound.jl...")
Main.include("branchAndBound.jl")
time_loading = time.time() - time_start_loading
print(f"Loading time for julia package(s) is {time_loading}s.")

k = args.k
time_limit = args.time_limit
file_path = args.filepath
epsilon = args.epsilon

# Honour --time_limit (default 600) instead of a hard-coded value.
results = random_testing(num_blocks=30, block_size=20, alpha=0.7, noise_scale=0.1, time_limit=time_limit, output_csv="synthetic_results_new_alg.csv")
print(results.head())



    


    
    






